Skip to content

Fix bus error or segfault from roi_align with large batchsize#9441

Open
zy1git wants to merge 12 commits intopytorch:mainfrom
zy1git:issue-8206
Open

Fix bus error or segfault from roi_align with large batchsize#9441
zy1git wants to merge 12 commits intopytorch:mainfrom
zy1git:issue-8206

Conversation

@zy1git
Copy link
Contributor

@zy1git zy1git commented Mar 13, 2026

Summary
Bug: roi_align in torchvision crashes with a bus error/segfault on CPU or returns silently wrong (all-zero) results on CUDA when the total number of output elements exceeds INT_MAX (~2.1 billion). This is caused by 32-bit int overflow in index arithmetic within the C++ and CUDA kernels.

Root Cause: The kernels use int for composite index calculations like n × channels × pooled_width × pooled_height and pointer offsets like (roi_batch_ind × channels + c) × height × width. When these products exceed 2,147,483,647, the int wraps to a negative value, causing out-of-bounds memory access.

Example: FasterRCNN with batch_size=172 generates ~172,000 ROIs. The output index reaches 171,999 × 256 × 7 × 7 = 2,157,555,456 > INT_MAX, which matches the reporter's observed threshold exactly.

Fix: Promoted int to int64_t for all index, offset, and stride variables in the relevant files.

Test Plan
New overflow non-regression test
pytest test/test_ops.py::TestRoIAlign::test_roi_align_large_index -v

Existing tests — verify no regressions
pytest test/test_ops.py::TestRoIAlign -v

Fixes #8206

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9441

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f40a46d with merge base d7400a3 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the cla signed label Mar 13, 2026
@zy1git zy1git marked this pull request as draft March 13, 2026 10:03
test/test_ops.py Outdated
output_bytes = n_rois * channels * pooled_h * pooled_w * 4 # float32
if output_bytes > 9 * 1024**3:
pytest.skip("Test requires ~9 GB of memory")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all these values are statically defined. This if block is either always True or always False.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. I agree. I removed that part in the new commit.

test/test_ops.py Outdated
x = torch.rand(num_imgs, channels, height, width, dtype=torch.float32, device=device)
rois = torch.zeros(n_rois, 5, dtype=torch.float32, device=device)
except RuntimeError:
pytest.skip("Not enough memory to allocate test tensors")
Copy link
Member

@NicolasHug NicolasHug Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please verify tests aren't not being skipped on the CI. If they pass, remove the try/except, if they don't, we'll have to consider other strategies to test this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. I verified that the tests are not being skipped on the CI and my devserver. I removed the try/except in the new commit.

template <typename T>
void roi_align_backward_kernel_impl(
int nthreads,
int64_t nthreads,
Copy link
Member

@NicolasHug NicolasHug Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why nthreads needs to be int64_t? It should never need to be that large? If it's for integer comparison to not warn, we could just cast?

Copy link
Contributor Author

@zy1git zy1git Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the question. nthreads in the backward kernel is grad.numel(), which equals n_rois × channels × pooled_h × pooled_w.

I changed to int nthreads in both .cpp and .cu files and ran the test (only the forward kernel test, no backward kernel test due to the large memory requirement). The CPU test passed but the CUDA test failed with all-zero output. The reason is that the CPU forward kernel doesn't use nthreads — it loops over n_rois separately, and the overflow is handled by the int64_t changes to index_n, index_n_c, and index. The CUDA forward kernel uses a flat loop with nthreads as the bound, so truncating to int caused nthreads to wrap to a negative value, making the loop condition immediately false and skipping all output computation — resulting in all-zero output.

The CPU backward kernel does use nthreads in the same flat-loop pattern as CUDA (for (int64_t index = 0; index < nthreads; ...)) and receives the same overflowing value via grad.numel(), so it needs int64_t for the same reason.

A backward-specific test would require large memory (output + grad_output + grad_input), which might be impractical for CI. Do we want to add one with a memory skip guard, or is the current forward-only test sufficient?

Copy link
Contributor Author

@zy1git zy1git Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for "If it's for integer comparison to not warn, we could just cast?", I think it isn't a warning issue. The problem is that the value could be actually large and gets truncated at the call site before the function body runs. In the author's reproducing example (batch_size=172, default 1000 proposals per image), nthreads is 172,000 × 256 × 7 × 7 = 2,157,568,000 > INT_MAX.

I added the backward kernel test in the latest commit. If I change int64_t nthreadsto int nthreads, the CPU backward test fails with all-zero gradients because nthreads gets truncated to a negative value and the loop never executes.

nthreads doesn't mean "number of threads" in this code, instead, it means "total number of output elements to process.", which could be very large.

Feel free to let me know if you have any questions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. For my own ref:

nthreads is the name for output_size in the cuda kernel:

roi_align_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
output_size,

where output size was already inferred as uint64 since size() returns uint64:

auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
auto output_size = num_rois * pooled_height * pooled_width * channels;

nthreads was previously implicitly cast to int32 when passed to roi_align_forward_kernel_impl causing an overflow. Now, with the change made to the function signature, it's properly kept as int64.

@zy1git zy1git marked this pull request as ready for review March 20, 2026 08:27
test/test_ops.py Outdated

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_roi_align_large_index(self, device):
"""Regression test for https://github.com/pytorch/vision/issues/8206"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Regression test for https://github.com/pytorch/vision/issues/8206"""
"""Non-regression test for https://github.com/pytorch/vision/issues/8206"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in the new commit.

test/test_ops.py Outdated

# Forward kernel test
assert result.shape == (n_rois, channels, pooled_h, pooled_w)
assert result.abs().sum() > 0, "roi_align returned all zeros — likely an index overflow bug"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and below, don't specify anything beyond the assert, pytest is already good at showing the right thing. Here, the message doesn't help that much either.

Suggested change
assert result.abs().sum() > 0, "roi_align returned all zeros — likely an index overflow bug"
assert result.abs().sum() > 0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in the new commit.

pc.pos1 = static_cast<int64_t>(y_low) * width + x_low;
pc.pos2 = static_cast<int64_t>(y_low) * width + x_high;
pc.pos3 = static_cast<int64_t>(y_high) * width + x_low;
pc.pos4 = static_cast<int64_t>(y_high) * width + x_high;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you double check that these casts are needed?

claude says: y_low and y_high are pixel coordinates bounded by height, and width is the image width. y_low * width + x_low is at most height * width, which is the number of pixels in a single channel of a single image. That's not going to overflow int.

Please check every single other change in this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. These casts are not needed and I fixed in the new commit. I checked other changes and fixed them accordingly.

template <typename T>
void roi_align_backward_kernel_impl(
int nthreads,
int64_t nthreads,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. For my own ref:

nthreads is the name for output_size in the cuda kernel:

roi_align_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
output_size,

where output size was already inferred as uint64 since size() returns uint64:

auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
auto output_size = num_rois * pooled_height * pooled_width * channels;

nthreads was previously implicitly cast to int32 when passed to roi_align_forward_kernel_impl causing an overflow. Now, with the change made to the function signature, it's properly kept as int64.

Comment on lines +29 to +30
int64_t index_n =
static_cast<int64_t>(n) * channels * pooled_width * pooled_height;
Copy link
Contributor Author

@zy1git zy1git Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type and cast are good.

const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
int64_t index_n_c =
index_n + static_cast<int64_t>(c) * pooled_width * pooled_height;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casting could be removed in the new commit.

Comment on lines +192 to +195
int64_t n_stride,
int64_t c_stride,
int64_t h_stride,
int64_t w_stride) {
Copy link
Contributor Author

@zy1git zy1git Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int64_t is used to be consistent with grad.stride() which returns int64_t, avoiding implicit narrowing conversions. However, the original code used int and the change is not necessary for fixing the bug, so I will use int instead of int64_t.

int64_t c_stride,
int64_t h_stride,
int64_t w_stride) {
for (int64_t index = 0; index < nthreads; index++) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int64_t is needed.

T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
T* offset_grad_input = grad_input +
((static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casting is needed.

((static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width);

int output_offset = n * n_stride + c * c_stride;
int64_t output_offset = static_cast<int64_t>(n) * n_stride + c * c_stride;
Copy link
Contributor Author

@zy1git zy1git Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type and cast are good.

template <typename T>
__global__ void roi_align_forward_kernel_impl(
int nthreads,
int64_t nthreads,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type is good

const T* rois,
T* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type is good

const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
const T* offset_input = input +
(static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cast is needed here.

template <typename T>
__global__ void roi_align_backward_kernel_impl(
int nthreads,
int64_t nthreads,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type is good.

Comment on lines +218 to +221
int64_t n_stride,
int64_t c_stride,
int64_t h_stride,
int64_t w_stride,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

types are good.

Comment on lines +253 to +254
const int64_t output_offset =
static_cast<int64_t>(n) * n_stride + c * c_stride;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type is good. The cast is technically redundant since n_stride is already int64_t and the multiplication would promote automatically. I kept it to make the 64-bit intent explicit.

int64_t h_stride,
int64_t w_stride,
const int64_t memory_span) {
CUDA_1D_KERNEL_LOOP_T(index, nthreads, int64_t) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type is good.

Comment on lines +269 to +270
const int64_t input_offset =
(static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type and cast are good.

Comment on lines +82 to +84
int64_t index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_input = input +
(static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type and cast are good

for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
int index = index_n_c + ph * pooled_width + pw;
int64_t index = index_n_c + ph * pooled_width + pw;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type is good

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Bus error or segfault from roi_align with large batchsize

2 participants